Skip to content

HealDA v2 Architecture#1758

Open
aayushg55 wants to merge 44 commits into
NVIDIA:mainfrom
aayushg55:ag/healda-v2-arch
Open

HealDA v2 Architecture#1758
aayushg55 wants to merge 44 commits into
NVIDIA:mainfrom
aayushg55:ag/healda-v2-arch

Conversation

@aayushg55

Copy link
Copy Markdown
Contributor

PhysicsNeMo Pull Request

Description

Checklist

Dependencies

Review Process

All PRs are reviewed by the PhysicsNeMo team before merging.

Depending on which files are changed, GitHub may automatically assign a maintainer for review.

We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI’s assessment of merge readiness and is not a qualitative judgment of your work, nor is
it an indication that the PR will be accepted / rejected.

AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.

aayushg55 added 30 commits June 24, 2026 15:34
Bring the ragged grouped-query pixel cross-attention (pixel latents attend
to per-pixel packed observation tokens) into the experimental healda package
as a first step toward a video/observation DiT block.

Layout follows PNM's optional-dependency conventions:
- triton is referenced only via OptionalImport (never a bare import), so the
  modules import without triton and the import-linter external-import contract
  stays satisfied.
- the compiled Triton kernels live in a private _pixel_attn_kernels.py backend
  (mirroring the warp _warp_impl pattern), imported lazily by the public
  pixel_cross_attention.py only when triton is available -- no per-chunk guards.
- triton_autotune_cache.py: standalone autotune-config persistence util,
  likewise OptionalImport-based.

Headers use the current PNM template; ruff + license-header checks pass.

Commit message authored by AI
A 4D (b, t, x, c) analog of physicsnemo.nn.DiTBlock for field sequences: keeps
the DiT template (pluggable spatial-attention backend via get_attention,
adaLN-Zero conditioning, drop-path, gated MLP) and adds two optional gated
sub-layers:
- temporal (video) attention across time, with a pluggable time<->space reshard
  for context parallelism (manual all-to-all or ShardTensor.redistribute);
- observation cross-attention (pixel latents attend to packed per-pixel obs
  tokens via the vendored PixelCrossAttention).

New modules: temporal_attention.py (TemporalAttention + RoPE + causal/window
mask), sharding.py (shard_x/shard_t all-to-all + ShardTensor variants),
obs_packing.py (AttentionPacking/PixelGroupMap, copied as-is). Tests cover the
CPU plain + temporal paths and the CUDA full spatial+obs+temporal path
(forward/backward + grad flow).

Commit message authored by AI
Replace the _pixel_attn_kernels() helper + _k alias with an inline
`from . import _pixel_attn_kernels as kernels` inside the kernel-launch
wrappers, matching PNM's lazy optional-backend pattern (cf.
mesh/visualization/draw_mesh.py importing _matplotlib_impl/_pyvista_impl).

Commit message authored by AI
Bundle the observation tokens and ragged packing metadata into one
ObsCrossAttention object (tokens + cu_seqlens_k + max_seqlen_k + group_map)
so VideoDiTBlock's obs sub-layer takes a single argument instead of separate
obs_tokens + packing args. Drop the redundant data-pipeline AttentionPacking
struct (its counts/npix/hpx_level/pixel_order/is_packed fields are unused by
the model). Add jaxtyping Float/Int shape hints across the block, temporal
attention, sharding, and obs-packing modules.

Commit message authored by AI
Use the released physicsnemo.models.dit shape names: the token feature axis is
'hidden_size' (not 'channels'/'dim') and conditioning is 'condition_embed_dim',
matching DiTBlock's annotations and our own hidden_size constructor arg.

Commit message authored by AI
Add the ragged pixel cross-attention test suite (16 tests) validating the
Triton kernel forward/backward against a readable PyTorch GQA reference
(_ragged_gqa_reference), plus packed-grid, small-pixel grouping (grouped ==
ungrouped bit-for-bit), nn.Module wiring, empty-tokens DDP-safety, and config
validation. Port build_pixel_group_map (the CSR small-pixel grouping helper,
pure function of cu_seqlens_k) into obs_packing.py to support the grouping test.

Commit message authored by AI
A HEALPix field-sequence diffusion transformer over (B, C, T, npix): reuses the
existing HEALPixPatchTokenizer/HEALPixPatchDetokenizer (which already fold the
time axis and add the calendar embedding) and EDM conditioning, reshapes the
flat token sequence to (B, T, X, hidden) for the VideoDiTBlock stack (spatial +
optional factorized temporal + optional observation cross-attention), then back
for the detokenizer. Observations enter as a prebuilt ObsCrossAttention bundle.

Tests: CPU dense+temporal and CUDA dense+temporal+obs forward/backward.

Commit message authored by AI
- VideoDiT no longer hardcodes HEALPix: it takes a pluggable tokenizer /
  detokenizer (the grid lives only in tokenization; the backbone is grid-
  agnostic) plus time_length, threading tokenizer_kwargs and reshaping the flat
  time-major token sequence to (B, T, X, hidden) for the blocks.
- VideoDiTBlock's spatial+MLP path now mirrors physicsnemo.nn.DiTBlock: renamed
  emb_channels -> condition_embed_dim and added the DiTBlock dropout args
  (attn_drop_rate, proj_drop_rate, mlp_drop_rate, final_mlp_dropout), wired into
  get_attention and the MLP. Its args are now a superset of DiTBlock's.

Commit message authored by AI
…pose VideoDiT

- VideoDiTBlock subclasses physicsnemo.nn.DiTBlock, reusing its spatial
  attention, gated MLP, pre-norms, 6-chunk adaLN-Zero and drop-path; adds
  optional gated temporal + observation cross-attention sub-layers.
- is_causal and obs/temporal config moved to __init__; obs args behind
  obs_kwargs, temporal behind temporal_kwargs (explicit Dicts, no **kwargs).
- VideoDiT inherits physicsnemo.Module and composes the conditioning embedder +
  pluggable tokenizer/detokenizer + blocks (no production DiT change).
- sharding: @torch._dynamo.disable on the ShardTensor reshard.

Commit message authored by AI
…-attention

Generalize the experimental video DiT into a grid-agnostic DiT with a time axis:
- add a shared ndim-agnostic AdaLayerNormZero with a zero_init toggle (using
  get_layer_norm) and SiLU inside the module;
- compose VideoDiTBlock from shared building blocks instead of subclassing
  DiTBlock, moving modulation into per-sub-layer AdaLayerNormZero;
- replace the obs-specific cross-attention with a generic pluggable
  CrossAttentionModuleBase slot + opaque context; PixelCrossAttention is the
  reference impl and now owns the fold/ragged-unpack;
- make VideoDiT a kwarg-superset of DiT (drop_path_rates, conditioning_embedder
  choice, attn_kwargs, block_kwargs, dit_initialization).

Commit message authored by AI
…oss-attn, adaLN naming

- VideoDiT/VideoDiTBlock: cross_attention is a per-block factory (no deepcopy);
  forward context typed Optional[Any]; conditioning resolution inlined.
- Time axis first-class: HEALPix tokenizer gains separate_time_axis; detokenizer
  infers flat-vs-time-first from input rank (backwards-compatible with the v1
  flat + time_length path).
- adaLN-Zero attrs renamed role-first: attn_norm / temporal_attn_norm /
  cross_attn_norm / mlp_norm.

Commit message authored by AI
… Triton kernel

Port the v2 FiLM-conditioned observation tokenizer as an initial drop-in:
- obs_film_tokenizer.py: ObsTokenizerFiLM module + pure-PyTorch reference,
  custom-op wrappers, and the fused_film_tokenizer_triton entry point. Forward
  dispatches to the Triton kernel on CUDA when triton is available, else the
  reference path.
- _film_kernels.py: private fused FiLM forward/backward Triton kernels, guarded
  by OptionalImport("triton") (no bare import) and imported lazily.
- test_obs_film_tokenizer.py: CPU reference smoke test plus CUDA Triton-vs-
  reference parity tests.

Settings are kept as-is for now (TODO(polish) markers on unused ones).

Commit message authored by AI
DiT.__init__ indexed input_size[1] unconditionally, which IndexErrors for
non-2D tokenizers (e.g. HEALPix, where HealDA passes a 1-tuple input_size).
The latent grid is only consumed by the NATTEN backends, so guard it.

Commit message authored by AI
Match the public module stem (obs_film_tokenizer.py), mirroring
pixel_cross_attention.py / _pixel_attn_kernels.py.

Commit message authored by AI
…kenizer

obs_film_tokenizer.py -> obs_tokenizer.py, _obs_film_kernels.py ->
_obs_tokenizer_kernels.py (class stays ObsTokenizerFiLM; "FiLM" distinguishes
the impl from the existing ObsTokenizer and leaves room to evolve).

Commit message authored by AI
Compose the upstreamed VideoDiT backbone with the FiLM ObsTokenizerFiLM and
per-block PixelCrossAttention into the production v2 video+obs data-assimilation
architecture (hidden 1536, 16 heads, 32 layers, time_length 8, linear causal
temporal attention, drop-path 0.0 for the first 4 blocks then 0.1), so the
existing healda checkpoint can be loaded. HealDAv2 hosts the FiLM obs tokenizer
and assembles the per-pixel ObsCrossAttention context the backbone consumes.

Commit message authored by AI
…tidy wording

- VideoDiT.set_context_parallel(mode, target) fans the temporal time<->space
  reshard config out to every block (the per-block setter had no caller, so CP
  was never actually enabled through the model).
- VideoDiT.forward asserts the tokenizer emits 4D (B,T,X,hidden).
- Drop the coined "field sequence(s)" phrasing from docstrings/comments.

Commit message authored by AI
…ntract

Make the base a general cross-attention sub-layer rather than one "injected
into a video DiT block": forward takes (*batch, hidden_size) latents and an
opaque context: Any. Drops the field-sequence wording and trims the docstring.

Commit message authored by AI
…ocks=2)

Replace the separate attn_norm + mlp_norm AdaLayerNormZero(n_blocks=1) pair with
a single norm1 = AdaLayerNormZero(n_blocks=2) that emits the spatial-attention
modulation plus the raw MLP shift/scale/gate, and a parameter-free LayerNorm
MLP pre-norm modulated by those. This matches the DiT/diffusers layout and the
production checkpoint's single norm1.linear. temporal_attn_norm/cross_attn_norm
stay one-block adaLNs; initialize_weights now zeroes norm1.

Commit message authored by AI
Rename the packed-observation container to ObsContext and make tokens optional
(unset until the observation tokenizer fills it). PixelCrossAttention.forward
now consumes an ObsContext, raising if tokens is unset; a None group_map keeps
using the ungrouped ragged path (the model layer does not build it). Updates
all references and tests.

Commit message authored by AI
…orward

ObsContext now also carries the raw per-observation arrays (values,
float_metadata, obs_type, channel, platform) alongside the ragged packing, so
HealDAv2.forward takes one obs: ObsContext instead of loose per-obs and packing
args. The forward runs the tokenizer, fills tokens via dataclasses.replace, and
passes the context through; it no longer builds the pixel group map (a None
group_map uses the ungrouped ragged path). build_pixel_group_map stays for
callers that precompute it.

Commit message authored by AI
…mporal RoPE

Replace the hand-rolled RotaryPositionEmbedding with PhysicsNeMo's
nn.module.rope.RotaryPositionEmbedding1D (math-verified equivalent: same
interleaved-pair rotation and theta^(-2k/d) schedule). The temporal q/k are
transposed so the time axis is the -2 dim the module rotates, then restored.
Its cos/sin are non-persistent buffers recomputed at init, so the old
persistent rope.freqs_cos/sin keys no longer exist.

Commit message authored by AI
…rim docstring

Remove the unused nchannel, nplatform, and use_global_channel_platform_ids
constructor params (the embedding tables are always sized GLOBAL_MAX_* and id
spaces are the caller's responsibility). Trim the editorial rationale and
conv/sat first-linear essays in the class docstring to terse facts.

Commit message authored by AI
…meter-free

HealDAv2 passes attn_kwargs={"qk_norm_type":"RMSNorm","qk_norm_affine":False}.
PhysicsNeMo's TimmSelfAttention declares those exact kwargs (translating them to
timm's qk_norm + RmsNorm norm_layer), so affine-free RMSNorm actually engages and
the names are not silently swallowed. Add a regression test that the spatial q/k
norm modules exist (not Identity) and have no learnable parameters.

Commit message authored by AI
Lightweight audit cleanups: jaxtyping annotations on the two HEALPix hpx
tokenizer forwards, a descriptive CalendarEmbedding ValueError message, an
accurate uneven-shards comment in sharding.py, a Literal["apex","torch"] hint on
AdaLayerNormZero.layernorm_backend, and removal of the dead VideoDiTBlock
self.hidden_size attribute. The hpx __init__ **kwargs safety nets are kept.

Commit message authored by AI
…n; tidy HealDAv2 wiring

Drop the per-block hot-path validation from PixelCrossAttention.forward (tokens-set
and cu_seqlens-vs-pixel-count) and instead validate the packing's structural shape
once in ObsContext.__post_init__; tokens-set is guaranteed by HealDAv2.forward.
Build the HEALPix (de)tokenizers as locals instead of inline in the VideoDiT call,
and note the drop-path zero-first stability rationale in the docstring.

Commit message authored by AI
…packing->obs_context

Extract the kernel-companion packing primitives into a single
pixel_attention_utils.py: sort_and_pack (inlined Triton counting-sort kernel,
guarded by triton availability, argsort fallback), counts_to_cu_seqlens, and
build_pixel_group_map. These operate on plain index/count tensors, so the data
pipeline builds the packing the model only consumes. obs_packing.py is now purely
the ObsContext + PixelGroupMap contract, renamed to obs_context.py. Also build
HealDAv2's cross-attention via functools.partial and keep the cross-attention
base docstring generic. Adds CPU parity tests for the packing utils.

Commit message authored by AI
Comment thread physicsnemo/experimental/models/healda/pixel_attention_utils.py Outdated
# Directory for the persisted Triton autotune cache. Overridable so deployments
# can redirect it to a fast/shared location; defaults under the user cache dir.
_AUTOTUNE_CACHE_DIR = os.environ.get(
"PHYSICSNEMO_CACHE_DIR", os.path.expanduser("~/.cache/physicsnemo")

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not a standard/documented env var for the package. Can we just rely on whatever Triton defaults to (both location and env var controlling)?

if self.use_rope:
self.rope = RotaryPositionEmbedding1D(
head_dim=self.head_dim, max_seq_len=max_seq_len, theta=rope_base
)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Heads up I'm thinking of refactoring these modules slightly so they are better for multi-block models (as most models are) -- instead of the embedding tables being rebuilt per attention block the rope module can be instead be created once as a submodule of the top-level network, and then the same embedding tables there can be reused in all blocks. Should be a small/minor refactor

from older checkpoints; set False for correct behaviour.
causal_window: When set (and the forward pass is causal), restrict each
frame to attend to itself and the previous ``causal_window - 1`` frames.
"""

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make into standard docstring format (see CODING_STANDARDS). Also this can probably be moved into the video_dit_block.py file to reduce the file count

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, will update to match the coding standards. I'd lean toward keeping temporal attention in its own module since the TemporalAttention module is quite separate from the VideoDitBlock and could be used as a more general module. But if you'd rather co-locate it, your call.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, on the consolidation, how about the following:

  1. Kernels + related utils in healda/kernels/ and video_dit_block added to video_dit as per other comment
  2. temporal_attention, cross_attention, and pixel_cross_attention merged into one attention_layers.py

Imo a singular ~1k line file with all the attention utils is better than disparate files

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, that works

with open(tmp, "w") as f:
json.dump(data, f, indent=2)
os.replace(tmp, path)
return n

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we maybe prepend a healda (or configurable) name to the paths to avoid possible mixing/pollution with other similarly named kernels? Not too familiar with triton's behavior here but seems like generally better practice to try and keep things self-contained if we are manipulating things in the cache. And there should be a warning to users in the docstrings of any modules that could trigger this behavior

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea agree that there might be a better way to do things. For context, the point of this cache is to cache the best kernel config from the autotuning sweep (as opposed to the generated kernel itself, whose caching triton internally controls and can be set through some env variables). This caching is more relevant for training as due to the dynamic obs count and each rank seeing a different number of obs at any given iter, they will be running the autotune at different points in the training, but all ranks are forced to stall if any stalls. Furthermore, across training runs, the autotuning is rerun, so this cost is normally payed every training job.

In contrast, this tries to just pay this cost upfront once and cache the autotuning result.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, how about we just make this behavior clearly documented in docstrings, and make it opt-in? i.e. unless user sets some env var specific to healda, no hidden/automatic cache behavior will happen

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest trying to consolidate all triton kernels and related utils into one file, there's a lot of different files here

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed there are many files under the healda/ folder and the grouping may be currently unclear, but there are a few different unrelated triton kernels involved and clumping them into a single multi-thousand line file makes it harder to follow.

I propose creating a a healda/kernels/ subdir with one file per op, and also moving the autograd glue/host dispatch code out of the current 1k-llne module files such that each module file is just the torch model layer and easier to follow without the implementation complexity.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could also fold the VideoDitBlock and VideoDiT into the same file to reduce file count

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok yeah I think a healda/kernels/ should help, let's go with that. Consolidating VideoDiTBlock also sounds good

Comment thread physicsnemo/experimental/models/healda/healda_v2.py
@negin513 negin513 self-requested a review June 30, 2026 15:13
@negin513

Copy link
Copy Markdown
Member

This PR adds a substantial new v2 model stack but doesn't surface any of the building blocks through the package's public API. I think it might worth exposing these similar to v1 through the public API.

@aayushg55

Copy link
Copy Markdown
Contributor Author

@aayushg55 I think VideoDiTBlock is fine as a standalone block, subclassing DiTBlock seems like it'd get messy. However a more general concern is I thought we were replacing the v1 healda architecture with this update (not supporting the older model/checkpoints), it looks like instead you've added the v2 in alongside and kept v1?

Sounds good. I wasn't sure if we wanted to be removing the v1 architecture too. But given that the v2 architecture is strictly better and the old architecture can be found in older PNM / my previous PR branch for reproducibility, keeping v1 would add dead code, so I can delete it.

@aayushg55

Copy link
Copy Markdown
Contributor Author

This PR adds a substantial new v2 model stack but doesn't surface any of the building blocks through the package's public API. I think it might worth exposing these similar to v1 through the public API.

What would this look like?

aayushg55 and others added 6 commits June 30, 2026 15:47
…pers

Split the adaLN-Zero op into a projection module and two apply helpers so
every sub-layer composes the same way and the pieces can be reused across
DiT/video blocks (3D and 4D states):

- adaln.py: AdaLNModulation (c -> 3*n_blocks shift/scale/gate chunks) plus
  standalone modulate() and gated_residual(); the affine-free pre-norm now
  lives at the call site instead of inside the module.
- video_dit_block.py: rewire to the uniform norm -> modulate -> gated_residual
  pattern (no block-0 special case); projections named norm1_modulation /
  temporal_attn_modulation / cross_attn_modulation, each attention sub-layer
  owning its own parameter-free pre-norm (attn_norm / mlp_norm /
  temporal_attn_norm / cross_attn_norm).
- video_dit.py: update the adaLN docstring reference.
- test: track the norm1_modulation rename.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…te _kernels modules

Consolidate each op's Triton device code and its host-side glue (launch
dispatch, autograd Function, custom-op registration) into the op's private
_*_kernels module, leaving the model files as just the nn.Module + public API:

- _pixel_attn_kernels.py / _obs_tokenizer_kernels.py: now own the full Triton
  op stack instead of only the @triton.jit kernels.
- pixel_cross_attention.py: slimmed to the PixelCrossAttention module;
  pixel_attention_reference -> _pixel_attention_reference (now private).
- obs_tokenizer.py: slimmed to the ObsTokenizerFiLM module.
- __init__.py: add package docstring + __all__; export PixelCrossAttention and
  ObsTokenizerFiLM alongside the existing public classes.
- pixel_attention_utils.py: docstring trim.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
The packing helpers (sort_and_pack, counts_to_cu_seqlens, build_pixel_group_map,
counting_sort_and_pack + its counting-sort Triton kernel) build the ragged
layout that pixel_attention consumes and are used only by pixel cross-attention,
so co-locate them in pixel_cross_attention.py and drop the separate module.
Repoint the two tests and update the obs_context doc reference.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Comment thread physicsnemo/experimental/models/healda/pixel_cross_attention.py Outdated
root and others added 2 commits July 1, 2026 09:54
- video_dit_block: set_context_parallel now type-checks target (ProcessGroup for
  "all_to_all", DeviceMesh for "shardtensor") with a clear TypeError.
- healda_v2: add a set_context_parallel passthrough to the backbone; expand the
  class docstring (data-flow stages, grid-agnostic boundary, context-parallel
  constraints) and Notes (obs packing, and that only the timm backend supports
  the RMSNorm / qk_norm_affine=False QK-norm used for stable training).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

@negin513 negin513 left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great job @aayushg55!

I’ll add a few comments inline.

One higher-level thing I noticed: this PR does not add an end-to-end training recipe or example showing how the new HealDAv2, VideoDiT, ObsContext, and observation/pixel-attention pieces are intended to be wired together...

Does it make sense to add a minimal example in this PR under examples/weather/healda/ ?

@aayushg55

Copy link
Copy Markdown
Contributor Author

One higher-level thing I noticed: this PR does not add an end-to-end training recipe or example showing how the new HealDAv2, VideoDiT, ObsContext, and observation/pixel-attention pieces are intended to be wired together...

Does it make sense to add a minimal example in this PR under examples/weather/healda/ ?

I think the idea was for this PR to only introduce the architecture to get it in earlier rather than later and enable adding inference capability in Earth2Studio before the next release. Further integrating the training loop and making sure it is up-to-standard would likely take significant engineering effort (and closer collaboration with the PNM team), as our internal training loops/dataloaders are not using PNM in any form at the moment and do not follow how existing PNM training loops are set up. Once the dataloader pieces are also added, we can follow up with the training loop.

Collapse PixelCrossAttention's input_dim/output_dim into one hidden_size (it
is a residual sub-layer that always used equal widths) and fix the output
reshape. Rename TemporalAttention embed_dim -> hidden_size and update callers.

Commit message authored by AI
@negin513

negin513 commented Jul 1, 2026

Copy link
Copy Markdown
Member

This PR adds a substantial new v2 model stack but doesn't surface any of the building blocks through the package's public API. I think it might worth exposing these similar to v1 through the public API.

What would this look like?

I was thinking we could update physicsnemo/experimental/models/healda/__init__.py to export the main user-facing v2 pieces, similar to how v1 exports HealDA.

At minimum, probably:

from .healda_v2 import HealDAv2, HealDAv2MetaData
from .obs_context import ObsContext, PixelGroupMap

Similar to HealDA-v1 stuff. But I would keep the Triton kernel modules private. The main goal is that a recipe or user can do from physicsnemo.experimental.models.healda import HealDAv2, ObsContext without importing deep module paths.

@aayushg55

Copy link
Copy Markdown
Contributor Author

This PR adds a substantial new v2 model stack but doesn't surface any of the building blocks through the package's public API. I think it might worth exposing these similar to v1 through the public API.

What would this look like?

I was thinking we could update physicsnemo/experimental/models/healda/__init__.py to export the main user-facing v2 pieces, similar to how v1 exports HealDA.

At minimum, probably:

from .healda_v2 import HealDAv2, HealDAv2MetaData
from .obs_context import ObsContext, PixelGroupMap

Similar to HealDA-v1 stuff. But I would keep the Triton kernel modules private. The main goal is that a recipe or user can do from physicsnemo.experimental.models.healda import HealDAv2, ObsContext without importing deep module paths.

Thanks, I have updated the physicsnemo/experimental/models/healda/init.py to export the v2 pieces

"UniformFusion",
"ScatterAggregator",
"scatter_mean",
]

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We definitely don't need to publicly export all of these (e.g. MetaData classes are not typically exported), and I would actually push for the majority of these to not be exported. Advanced users who want to go in and access them can use a direct import but otherwise we only need to export the HealDA architecture and maybe any components that could be useful for people working with custom obs data.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And a more major change, but one I think is worth doing, is to outright remove the healda v1 architecture and replace it with this newer/better version. We decided we are not going to support the legacy v1 model (users can still access it by installing an older version of phyiscsnemo+earth2studio), and planned for it by version-capping the physicsnemo source for healda in last earth2studio release. So we should be free to remove/replace it here.

Ultimately the "v2" is not very meaningful for most external users and I think it's more clear to have one class for the architecture. We don't need to wait a release cycle to deprecate the older one since it's all in experimental

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 @pzharrington, thanks for the context. The MetaData and HealDAv2 was my suggestion. I was mostly thinking about discoverability from the recipe/user side on useful utilities such as ObsContext, but I agree we should keep the public API tighter and not export internals unnecessarily.

Given your point that we do not plan to support the legacy v1 model, I agree that replacing the public HealDA export with the newer implementation is cleaner than exposing a parallel HealDAv2. Then we can keep the API focused on the architecture itself, plus only the obs-data pieces that recipes/custom data users actually need, e.g. ObsContext (?) if users are expected to construct those directly. @pzharrington, I 100% follow your judgment on this. + @aayushg55 maybe you can comment around this what utils from HEALDAv2 will be helpful for future to expose through API.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the cross attention module (PixelCrossAttention), tokenizer (ObsTokenizerFiLM), the obs wrapper (ObsContext), the combined architecture (HealDAv2 renamed to HealDA), and probably the VideoDiT would be things users might be interested in using.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good, let's go with that then

:class:`physicsnemo.nn.module.hpx.tokenizer.HEALPixPatchTokenizer`.
2. A :class:`.video_dit.VideoDiT` backbone processes the token sequence with
spatial attention, factorized temporal attention, and adaLN-Zero
conditioning built from the EDM noise embedding and the calendar

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

EDM noise embedding? This isn't a diffusion model so that's a bit confusing

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The architecture uses all of the DiT architecture conditioning components (AdaLN + noise/condition label). We trained it as a regression model (huber loss) setting noise to always be 0, but it could be trained as a diffusion model too.

@pzharrington pzharrington Jul 1, 2026

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I suppose that's true. I'm just not sure about advertising that here quite yet -- partly since no one has evaluated/experimented with the model in that realm as far as I know, but mainly because as of now making it a diffusion model would have to be a manual implementation. The forward signature expected by all physicsnemo.diffusion components is (x, t, condition: torch.Tensor | TensorDict) (see here) so it would take a bit of work to massage the the current Healdav2 into compliance, or one would have to write their own diffusion loop from scratch (which we don't necessarily want to encourage for obvious reasons)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense, I can revise the docstring

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants